
from monai.networks.nets.unetr import UNETR
import torch.nn as nn 
import torch 

class UNETRDiff(nn.Module):
    """Thin ``nn.Module`` wrapper around MONAI's :class:`UNETR`.

    With no arguments this builds the same configuration the class has
    always used: 1 input channel, 2 output channels, a 128x128x128 input
    volume, ``feature_size=16``, ``hidden_size=768``, ``mlp_dim=3072``,
    ``num_heads=12`` and ``pos_embed="conv"``.

    Args:
        in_channels: Number of input channels passed to UNETR.
        out_channels: Number of output channels passed to UNETR.
        img_size: Input spatial size passed to UNETR as ``img_size``.
        feature_size: Passed to UNETR as ``feature_size``.
        hidden_size: Passed to UNETR as ``hidden_size``.
        mlp_dim: Passed to UNETR as ``mlp_dim``.
        num_heads: Passed to UNETR as ``num_heads``.
        pos_embed: Passed to UNETR as ``pos_embed``.
        model: Optional pre-built backbone. When given, it is stored as-is
            and every UNETR hyperparameter above is ignored.
    """

    def __init__(
        self,
        *,
        in_channels: int = 1,
        out_channels: int = 2,
        img_size: "tuple[int, int, int] | int" = (128, 128, 128),
        feature_size: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
        pos_embed: str = "conv",
        model: "nn.Module | None" = None,
    ) -> None:
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise ``self.model = ...`` raises AttributeError.
        super().__init__()
        if model is None:
            # NOTE(review): newer MONAI releases rename ``pos_embed`` to
            # ``proj_type``; confirm against the installed MONAI version.
            model = UNETR(
                in_channels,
                out_channels,
                img_size=img_size,
                feature_size=feature_size,
                hidden_size=hidden_size,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                pos_embed=pos_embed,
            )
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped backbone on ``x`` and return its output.

        For the default UNETR backbone, ``x`` is presumably shaped
        ``(batch, in_channels, *img_size)``; confirm against callers.
        """
        return self.model(x)
        